/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: BUSL-1.1
*/

package measurements

import (
	"fmt"

	"github.com/edgelesssys/constellation/v2/internal/attestation/variant"
	"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
)

// measurementOverridesForCSP holds the measurement overrides that
// ApplyOverrides applies for each cloud service provider, keyed by
// cloudprovider.Provider.String().
//
// For every provider entry:
//   - MustEnforce lists PCR indices whose validation option is set to Enforce.
//   - MustWarn lists PCR indices whose validation option is set to WarnOnly.
//   - ValueOverrides pin the expected value of individual PCRs. ApplyOverrides
//     creates the PCR entry if the input measurements do not contain it.
//
// ApplyOverrides applies these CSP overrides after any attestation-variant
// overrides, so values set here take precedence for the same PCR index.
//
// NOTE(review): the repeated value 0x3d458cfe…0x79, 0x69 appears to be the
// well-known value of a PCR into which only an EV_SEPARATOR event was
// extended (i.e. no other firmware events measured). Confirm before relying
// on that interpretation.
var measurementOverridesForCSP = map[string]measurementOverride{
	// AWS: PCRs 2, 3 and 6 are pinned to the 0x3d45… value and PCR 14 to all
	// zeros; all of them, plus PCR 0, only produce warnings on mismatch.
	cloudprovider.AWS.String(): {
		MustEnforce: []uint32{
			4, 8, 9, 11, 12, 13, uint32(PCRIndexClusterID),
		},
		MustWarn: []uint32{
			0, 2, 3, 6, 14,
		},
		ValueOverrides: []valueOverride{
			{Index: 2, Value: []byte{0x3d, 0x45, 0x8c, 0xfe, 0x55, 0xcc, 0x03, 0xea, 0x1f, 0x44, 0x3f, 0x15, 0x62, 0xbe, 0xec, 0x8d, 0xf5, 0x1c, 0x75, 0xe1, 0x4a, 0x9f, 0xcf, 0x9a, 0x72, 0x34, 0xa1, 0x3f, 0x19, 0x8e, 0x79, 0x69}},
			{Index: 3, Value: []byte{0x3d, 0x45, 0x8c, 0xfe, 0x55, 0xcc, 0x03, 0xea, 0x1f, 0x44, 0x3f, 0x15, 0x62, 0xbe, 0xec, 0x8d, 0xf5, 0x1c, 0x75, 0xe1, 0x4a, 0x9f, 0xcf, 0x9a, 0x72, 0x34, 0xa1, 0x3f, 0x19, 0x8e, 0x79, 0x69}},
			{Index: 6, Value: []byte{0x3d, 0x45, 0x8c, 0xfe, 0x55, 0xcc, 0x03, 0xea, 0x1f, 0x44, 0x3f, 0x15, 0x62, 0xbe, 0xec, 0x8d, 0xf5, 0x1c, 0x75, 0xe1, 0x4a, 0x9f, 0xcf, 0x9a, 0x72, 0x34, 0xa1, 0x3f, 0x19, 0x8e, 0x79, 0x69}},
			{Index: 14, Value: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
		},
	},
	// Azure: PCRs 1, 2 and 3 are pinned to the 0x3d45… value and PCR 14 to
	// all zeros; all four only produce warnings on mismatch.
	cloudprovider.Azure.String(): {
		MustEnforce: []uint32{
			4, 8, 9, 11, 12, 13, uint32(PCRIndexClusterID),
		},
		MustWarn: []uint32{
			1, 2, 3, 14,
		},
		ValueOverrides: []valueOverride{
			{Index: 1, Value: []byte{0x3d, 0x45, 0x8c, 0xfe, 0x55, 0xcc, 0x03, 0xea, 0x1f, 0x44, 0x3f, 0x15, 0x62, 0xbe, 0xec, 0x8d, 0xf5, 0x1c, 0x75, 0xe1, 0x4a, 0x9f, 0xcf, 0x9a, 0x72, 0x34, 0xa1, 0x3f, 0x19, 0x8e, 0x79, 0x69}},
			{Index: 2, Value: []byte{0x3d, 0x45, 0x8c, 0xfe, 0x55, 0xcc, 0x03, 0xea, 0x1f, 0x44, 0x3f, 0x15, 0x62, 0xbe, 0xec, 0x8d, 0xf5, 0x1c, 0x75, 0xe1, 0x4a, 0x9f, 0xcf, 0x9a, 0x72, 0x34, 0xa1, 0x3f, 0x19, 0x8e, 0x79, 0x69}},
			{Index: 3, Value: []byte{0x3d, 0x45, 0x8c, 0xfe, 0x55, 0xcc, 0x03, 0xea, 0x1f, 0x44, 0x3f, 0x15, 0x62, 0xbe, 0xec, 0x8d, 0xf5, 0x1c, 0x75, 0xe1, 0x4a, 0x9f, 0xcf, 0x9a, 0x72, 0x34, 0xa1, 0x3f, 0x19, 0x8e, 0x79, 0x69}},
			{Index: 14, Value: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
		},
	},
	// GCP: PCR 1 has its own pinned value (0x3695…); PCRs 2, 3 and 6 use the
	// 0x3d45… value and PCR 14 all zeros. All five only produce warnings.
	cloudprovider.GCP.String(): {
		MustEnforce: []uint32{
			4, 8, 9, 11, 12, 13, uint32(PCRIndexClusterID),
		},
		MustWarn: []uint32{
			1, 2, 3, 6, 14,
		},
		ValueOverrides: []valueOverride{
			{Index: 1, Value: []byte{0x36, 0x95, 0xdc, 0xc5, 0x5e, 0x3a, 0xa3, 0x40, 0x27, 0xc2, 0x77, 0x93, 0xc8, 0x5c, 0x72, 0x3c, 0x69, 0x7d, 0x70, 0x8c, 0x42, 0xd1, 0xf7, 0x3b, 0xd6, 0xfa, 0x4f, 0x26, 0x60, 0x8a, 0x5b, 0x24}},
			{Index: 2, Value: []byte{0x3d, 0x45, 0x8c, 0xfe, 0x55, 0xcc, 0x03, 0xea, 0x1f, 0x44, 0x3f, 0x15, 0x62, 0xbe, 0xec, 0x8d, 0xf5, 0x1c, 0x75, 0xe1, 0x4a, 0x9f, 0xcf, 0x9a, 0x72, 0x34, 0xa1, 0x3f, 0x19, 0x8e, 0x79, 0x69}},
			{Index: 3, Value: []byte{0x3d, 0x45, 0x8c, 0xfe, 0x55, 0xcc, 0x03, 0xea, 0x1f, 0x44, 0x3f, 0x15, 0x62, 0xbe, 0xec, 0x8d, 0xf5, 0x1c, 0x75, 0xe1, 0x4a, 0x9f, 0xcf, 0x9a, 0x72, 0x34, 0xa1, 0x3f, 0x19, 0x8e, 0x79, 0x69}},
			{Index: 6, Value: []byte{0x3d, 0x45, 0x8c, 0xfe, 0x55, 0xcc, 0x03, 0xea, 0x1f, 0x44, 0x3f, 0x15, 0x62, 0xbe, 0xec, 0x8d, 0xf5, 0x1c, 0x75, 0xe1, 0x4a, 0x9f, 0xcf, 0x9a, 0x72, 0x34, 0xa1, 0x3f, 0x19, 0x8e, 0x79, 0x69}},
			{Index: 14, Value: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
		},
	},
	// OpenStack: only PCR 14 is pinned (all zeros) and downgraded to a warning.
	cloudprovider.OpenStack.String(): {
		MustEnforce: []uint32{
			4, 8, 9, 11, 12, 13, uint32(PCRIndexClusterID),
		},
		MustWarn: []uint32{
			14,
		},
		ValueOverrides: []valueOverride{
			{Index: 14, Value: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
		},
	},
}

// measurementOverridesForAttestationVariant holds the measurement overrides
// that ApplyOverrides applies for specific attestation variants, keyed by the
// variant's String() value.
//
// ApplyOverrides applies these before the CSP-wide overrides in
// measurementOverridesForCSP. The entries here only set expected values; they
// do not change validation options. For AWS, PCR 0 is additionally listed in
// the AWS entry's MustWarn in measurementOverridesForCSP.
//
// NOTE(review): the PCR 0 values are presumably the firmware measurements of
// the respective AWS platform (Nitro TPM vs. SEV-SNP). Confirm their origin
// before updating them.
var measurementOverridesForAttestationVariant = map[string]measurementOverride{
	// AWS Nitro TPM: pinned expected value for PCR 0.
	variant.AWSNitroTPM{}.String(): {
		ValueOverrides: []valueOverride{
			{Index: 0, Value: []byte{0x73, 0x7f, 0x76, 0x7a, 0x12, 0xf5, 0x4e, 0x70, 0xee, 0xcb, 0xc8, 0x68, 0x40, 0x11, 0x32, 0x3a, 0xe2, 0xfe, 0x2d, 0xd9, 0xf9, 0x07, 0x85, 0x57, 0x79, 0x69, 0xd7, 0xa2, 0x01, 0x3e, 0x8c, 0x12}},
		},
	},
	// AWS SEV-SNP: pinned expected value for PCR 0.
	variant.AWSSEVSNP{}.String(): {
		ValueOverrides: []valueOverride{
			{Index: 0, Value: []byte{0xd6, 0xdf, 0x85, 0x53, 0x58, 0xf5, 0xb1, 0x0f, 0x06, 0xf0, 0xfa, 0xb3, 0xf4, 0x08, 0xad, 0x26, 0xcd, 0x16, 0x5a, 0x29, 0x49, 0xba, 0xd6, 0x9e, 0x2c, 0xc7, 0x56, 0x92, 0x52, 0x9e, 0x66, 0x2a}},
		},
	},
}

type (
	// measurementOverride bundles the adjustments ApplyOverrides makes to a
	// set of measurements for one cloud provider or attestation variant.
	measurementOverride struct {
		// MustEnforce and MustWarn list PCR indices whose validation option
		// is forced to Enforce and WarnOnly, respectively.
		MustEnforce, MustWarn []uint32

		// ValueOverrides replace the expected value of individual PCRs.
		ValueOverrides []valueOverride
	}

	// valueOverride sets the expected value of the PCR at Index to Value.
	valueOverride struct {
		Index uint32
		Value []byte
	}
)

// ApplyOverrides applies overrides to the given measurements.
//
// The result starts from in.Copy(). The override registered for
// attestationVariant (measurementOverridesForAttestationVariant) is applied
// first, followed by the one registered for csp (measurementOverridesForCSP).
// If both touch the same PCR, the later (CSP) override wins. Within each
// override:
//   - ValueOverrides set the expected value of a PCR, creating the entry if it
//     does not exist yet. The value is copied, so the returned measurements
//     never share memory with the package-level override tables.
//   - MustEnforce sets the validation option of the listed PCRs to Enforce.
//   - MustWarn sets the validation option of the listed PCRs to WarnOnly.
//
// An error (with a nil result) is returned if a PCR listed in MustEnforce or
// MustWarn is not present in the measurements when that override is applied.
func ApplyOverrides(in M, csp cloudprovider.Provider, attestationVariant string) (M, error) {
	out := in.Copy()

	// Order matters: variant-specific overrides first, CSP-wide ones second.
	var matchingOverrides []measurementOverride
	if override, ok := measurementOverridesForAttestationVariant[attestationVariant]; ok {
		matchingOverrides = append(matchingOverrides, override)
	}
	if override, ok := measurementOverridesForCSP[csp.String()]; ok {
		matchingOverrides = append(matchingOverrides, override)
	}

	for _, override := range matchingOverrides {
		for _, v := range override.ValueOverrides {
			// A missing index yields the zero Measurement, which is then populated.
			m := out[v.Index]
			// Copy instead of aliasing v.Value: the override tables are
			// package-level state, and handing out their backing arrays would
			// let a caller that mutates Expected in place corrupt every later call.
			m.Expected = append([]byte(nil), v.Value...)
			out[v.Index] = m
		}
		for _, i := range override.MustEnforce {
			m, ok := out[i]
			if !ok {
				return nil, fmt.Errorf("missing measurement for PCR %d", i)
			}
			m.ValidationOpt = Enforce
			out[i] = m
		}
		for _, i := range override.MustWarn {
			m, ok := out[i]
			if !ok {
				return nil, fmt.Errorf("missing measurement for PCR %d", i)
			}
			m.ValidationOpt = WarnOnly
			out[i] = m
		}
	}
	return out, nil
}
